import os
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
import numpy as np
import torch
import random
import lpips
import torchvision.transforms as TF
from PIL import Image

try:
    from tqdm import tqdm
except ImportError:
    # tqdm is optional. Fall back to a no-op wrapper that returns the iterable
    # unchanged and ignores any progress-bar options (desc=, total=, ...), so
    # call sites written for the real tqdm keep working.
    def tqdm(x, *args, **kwargs):
        return x

from image_synthesis.utils.misc import get_all_file
from image_synthesis.utils.io import save_dict_to_json

# Command-line interface. ArgumentDefaultsHelpFormatter appends each default to its help text.
# NOTE: --batch-size and --num-workers are parsed but not read by main() in this script.
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
parser.add_argument('--batch-size', type=int, default=50,
                    help='Batch size to use')
parser.add_argument('--num-workers', type=int, default=8,
                    help='Number of processes to use for data loading')
parser.add_argument('--device', type=str, default=None,
                    help='Device to use. Like cuda, cuda:0 or cpu')
parser.add_argument('--path', type=str, default='',
                    help='Path to the directory of generated images')
parser.add_argument('--count', type=int, default=5000,
                    help='Maximum number of images sampled from --path for evaluation')
parser.add_argument('--im_size', type=str, default='256',
                    help='Size for resize + center crop: a single int for a square '
                         'image, or "height,width"')
parser.add_argument('--loops', type=int, default=2500,
                    help='Number of distinct random image pairs to compare')
parser.add_argument('--net', type=str, default='vgg',
                    choices=['vgg', 'alex'],
                    help='Backbone network used by LPIPS')


# File-name extensions passed to get_all_file(..., end_with=...) to select image files.
# Note: only 'JPEG' has an upper-case variant listed here.
IMAGE_EXTENSIONS = set(
    'bmp jpg jpeg pgm png ppm tif tiff webp JPEG'.split()
)


class ImagePathDataset(torch.utils.data.Dataset):
    """Map-style dataset that loads images from file paths as normalized tensors.

    Each item is the RGB image at ``files[i]``, turned into a float tensor and
    mapped through ``(x - 0.5) / 0.5``. For inputs in [0, 1] this yields [-1, 1],
    the input range the ``lpips`` package documents.

    Args:
        files: Sequence of image file paths.
        transforms: Optional callable applied to the RGB PIL image. It should
            return a float tensor in [0, 1] (e.g. a pipeline ending in
            ``ToTensor``). When ``None``, the image is converted to a
            ``(C, H, W)`` float32 tensor in [0, 1] directly.
        count: Optional cap; only the first ``count`` files are exposed.
    """

    def __init__(self, files, transforms=None, count=None):
        self.files = files
        self.transforms = transforms
        self.count = count

    def __len__(self):
        if self.count is not None:
            return min(self.count, len(self.files))
        return len(self.files)

    def __getitem__(self, i):
        # Enforce the same bound as __len__. Plain iteration stops only on
        # IndexError, and negative indices must wrap within the capped length,
        # not the full file list.
        n = len(self)
        if i < 0:
            i += n
        if not 0 <= i < n:
            raise IndexError('index {} out of range for dataset of size {}'.format(i, n))
        path = self.files[i]
        # Close the file handle deterministically; convert() returns an
        # independent, fully loaded image.
        with Image.open(path) as im:
            img = im.convert('RGB')
        if self.transforms is not None:
            img = self.transforms(img)
        else:
            # ToTensor equivalent for an RGB image: HWC uint8 -> CHW float32 in [0, 1].
            img = torch.from_numpy(np.array(img, dtype=np.float32) / 255.0).permute(2, 0, 1)
        img = (img - 0.5) / 0.5
        return img


def get_dataset(path, count=None, im_size=None):
    """Build an ImagePathDataset from a random subset of the images under ``path``.

    The files returned by ``get_all_file`` are shuffled in place with the
    module-level ``random`` generator, then truncated to ``count`` when given.
    If ``im_size`` is set, each image is resized, center-cropped to ``im_size``
    and converted to a tensor. Otherwise it is only converted to a tensor.
    """
    image_files = get_all_file(path, end_with=IMAGE_EXTENSIONS)
    random.shuffle(image_files)
    selected = image_files if count is None else image_files[:count]

    print('Evaluate {} images'.format(len(selected)))

    if im_size is None:
        pipeline = TF.ToTensor()
    else:
        steps = [TF.Resize(im_size), TF.CenterCrop(im_size), TF.ToTensor()]
        pipeline = TF.Compose(steps)

    return ImagePathDataset(selected, transforms=pipeline)


def calculate_lpips_value(path, device, net, loops, count=None, im_size=None):
    """Compute mean/std LPIPS distance over random distinct image pairs.

    Images under ``path`` are loaded via ``get_dataset``. ``loops`` distinct
    unordered pairs are drawn at random, their LPIPS distances are computed,
    and the mean and std are written as JSON into ``path``.

    Args:
        path: Directory of images. The JSON result is also written here.
        device: torch device for the LPIPS network and the image tensors.
        net: LPIPS backbone name (the CLI allows 'vgg' or 'alex').
        loops: Number of distinct unordered pairs to evaluate. It is capped at
            n * (n - 1) / 2 for n images, since each pair is used at most once.
        count: Optional cap on the number of images, forwarded to
            ``get_dataset``. It also appears in the output file name.
        im_size: Optional resize/crop size forwarded to ``get_dataset``.

    Returns:
        dict with float entries ``'mean'`` and ``'std'``. With a single pair,
        ``torch.std`` (unbiased by default) yields NaN.

    Raises:
        ValueError: if fewer than two images are found or ``loops < 1``.
    """
    dataset = get_dataset(path, count, im_size)
    num_images = len(dataset)
    if num_images < 2:
        raise ValueError('At least 2 images are required to compute LPIPS, '
                         'found {} in {}'.format(num_images, path))
    if loops < 1:
        raise ValueError('loops must be a positive integer, got {}'.format(loops))

    # Pairs are never repeated. Requesting more pairs than exist would make the
    # rejection-sampling loop below spin forever, so cap the request.
    max_pairs = num_images * (num_images - 1) // 2
    if loops > max_pairs:
        print('Requested {} pairs but only {} distinct pairs exist; '
              'evaluating {} pairs'.format(loops, max_pairs, max_pairs))
        loops = max_pairs

    loss = lpips.LPIPS(net=net).to(device)

    processed_pairs = set()
    im_index = list(range(num_images))

    values = []
    with torch.no_grad():
        for _ in tqdm(range(loops)):
            # Rejection-sample until an unordered pair not seen before is drawn.
            two_idx = tuple(sorted(random.sample(im_index, 2)))
            while two_idx in processed_pairs:
                two_idx = tuple(sorted(random.sample(im_index, 2)))
            processed_pairs.add(two_idx)

            im1 = dataset[two_idx[0]].unsqueeze(dim=0).to(device)
            im2 = dataset[two_idx[1]].unsqueeze(dim=0).to(device)

            # Flatten the LPIPS output to one distance per batch element.
            v = loss(im1, im2).view(im1.shape[0]).to('cpu')
            values.append(v)

    values = torch.cat(values, dim=0)
    statics = {
        'mean': values.mean().item(),
        'std': values.std().item(),
    }
    save_path = os.path.join(path, 'lpips_score_count{}_loops{}.json'.format(count, loops))
    save_dict_to_json(statics, save_path)
    print('saved to {}'.format(save_path))
    return statics

def main():
    """Parse CLI arguments, choose a device, and run the LPIPS evaluation.

    ``--im_size`` is parsed from a comma-separated string. A single value
    becomes an int (square resize/crop), and several values become a list of
    ints. An empty string disables resizing/cropping.
    """
    args = parser.parse_args()
    if args.im_size:
        sizes = [int(s) for s in args.im_size.split(',')]
        args.im_size = sizes[0] if len(sizes) == 1 else sizes
    else:
        args.im_size = None

    if args.device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    else:
        device = torch.device(args.device)

    statics = calculate_lpips_value(path=args.path,
                                    device=device,
                                    net=args.net,
                                    loops=args.loops,
                                    count=args.count,
                                    im_size=args.im_size)
    print('statics: ', statics)


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when imported.
    main()